/*
 * AIEngine a new generation network intrusion detection system.
 *
 * Copyright (C) 2013-2023  Luis Campo Giralte
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Library General Public
 * License as published by the Free Software Foundation; either
 * version 2 of the License, or (at your option) any later version.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Library General Public License for more details.
 *
 * You should have received a copy of the GNU Library General Public
 * License along with this library; if not, write to the
 * Free Software Foundation, Inc., 51 Franklin St, Fifth Floor,
 * Boston, MA  02110-1301, USA.
 *
 * Written by Luis Campo Giralte <luis.camp0.2009@gmail.com>
 *
 */
#include "SSLProtocol.h"
#include <iomanip> // setw

namespace aiengine {

SSLProtocol::~SSLProtocol() {

	// Release this protocol's reference held in anomaly_
	anomaly_ = nullptr;
}

bool SSLProtocol::check(const Packet &packet) {

	// A packet is accepted as SSL/TLS when it carries at least a record header,
	// the major version byte is 0x03 and the content type is a known record type.
	const int length = packet.getLength();

	if (length >= header_size) {
		const uint8_t *payload = packet.getPayload();
		const uint8_t content_type = payload[0];
		const bool major_is_ssl3 = (payload[1] == 0x03);
		const bool known_content_type = (content_type >= SSL3_CT_CHANGE_CIPHER_SPEC) and (content_type <= SSL3_CT_APPLICATION_DATA);

		if (major_is_ssl3 and known_content_type) {
			setHeader(payload);
			++total_valid_packets_;
			return true;
		}
	}

	++total_invalid_packets_;
	return false;
}

void SSLProtocol::setDynamicAllocatedMemory(bool value) {

	// Apply the same allocation policy to every cache owned by the protocol
	auto apply_policy = [value](auto &cache) { cache->setDynamicAllocatedMemory(value); };

	apply_policy(info_cache_);
	apply_policy(host_cache_);
	apply_policy(issuer_cache_);
	apply_policy(session_cache_);
#if defined(HAVE_JA3)
	apply_policy(ja3_cache_);
#endif
}

bool SSLProtocol::isDynamicAllocatedMemory() const {

	// setDynamicAllocatedMemory() sets every cache to the same value, so the
	// info cache is used as the reference for the whole protocol
	const auto &reference_cache = info_cache_;

	return reference_cache->isDynamicAllocatedMemory();
}

uint64_t SSLProtocol::getCurrentUseMemory() const {

	// Size of the protocol object plus the in-use memory reported by each cache
	uint64_t total = sizeof(SSLProtocol);
	auto add_in_use = [&total](const auto &cache) { total += cache->getCurrentUseMemory(); };

	add_in_use(info_cache_);
	add_in_use(host_cache_);
	add_in_use(issuer_cache_);
	add_in_use(session_cache_);
#if defined(HAVE_JA3)
	add_in_use(ja3_cache_);
#endif
	return total;
}

uint64_t SSLProtocol::getAllocatedMemory() const {

	// Size of the protocol object plus the memory each cache has allocated
	uint64_t total = sizeof(SSLProtocol);
	auto add_allocated = [&total](const auto &cache) { total += cache->getAllocatedMemory(); };

	add_allocated(info_cache_);
	add_allocated(host_cache_);
	add_allocated(issuer_cache_);
	add_allocated(session_cache_);
#if defined(HAVE_JA3)
	add_allocated(ja3_cache_);
#endif
	return total;
}

uint64_t SSLProtocol::getTotalAllocatedMemory() const {

	// Cache allocations plus the bytes held by the string keys of the lookup maps
	const uint64_t caches = getAllocatedMemory();
	const uint64_t maps = compute_memory_used_by_maps();

	return caches + maps;
}

uint64_t SSLProtocol::compute_memory_used_by_maps() const {

	// Sum of the key lengths stored in every lookup map
	auto key_bytes = [](const auto &map) {
		uint64_t total = 0;
		for (const auto &entry : map)
			total += entry.first.size();
		return total;
	};

	uint64_t bytes = key_bytes(host_map_);
	bytes += key_bytes(issuer_map_);
	bytes += key_bytes(session_map_);
#if defined(HAVE_JA3)
	bytes += key_bytes(ja3_map_);
#endif
	return bytes;
}

uint32_t SSLProtocol::getTotalCacheMisses() const {

	// Total number of failed acquisitions across all caches
	uint32_t misses = 0;
	auto add_fails = [&misses](const auto &cache) { misses += cache->getTotalFails(); };

	add_fails(info_cache_);
	add_fails(host_cache_);
	add_fails(issuer_cache_);
	add_fails(session_cache_);
#if defined(HAVE_JA3)
	add_fails(ja3_cache_);
#endif
	return misses;
}

void SSLProtocol::releaseCache() {

        if (FlowManagerPtr fm = flow_mng_.lock(); fm) {
                auto ft = fm->getFlowTable();

                std::ostringstream msg;
                msg << "Releasing " << name() << " cache";

                infoMessage(msg.str());

		int64_t total_cache_bytes_released = compute_memory_used_by_maps();
		int64_t total_cache_save_bytes = 0;
		int64_t total_bytes_released_by_flows = 0;
                int32_t release_flows = 0;
                int32_t release_hosts = host_map_.size();
                int32_t release_issuers = issuer_map_.size();
                int32_t release_sessions = session_map_.size();
#if defined(HAVE_JA3)
		int32_t release_fingerprints = ja3_map_.size();
#endif

                for (auto &flow: ft) {
		       	if (auto info = flow->getSSLInfo(); info) {
                                total_bytes_released_by_flows += sizeof(info);

                                flow->layer7info.reset();
                                info_cache_->release(info);
				++release_flows;
                        }
                }
		// Some entries can be still on the maps and needs to be
		// retrieve to their existing caches
		for (auto &entry: host_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(host_cache_, entry.second.sc);
		}
	        host_map_.clear();

		for (auto &entry: issuer_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(issuer_cache_, entry.second.sc);
		}
		issuer_map_.clear();

		for (auto &entry: session_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(session_cache_, entry.second.sc);
		}
		session_map_.clear();
#if defined(HAVE_JA3)
		for (auto &entry: ja3_map_) {
			total_cache_save_bytes += entry.second.sc->size() * (entry.second.hits - 1);
			releaseStringToCache(ja3_cache_, entry.second.sc);
		}
		ja3_map_.clear();
#endif
		data_time_ = boost::posix_time::microsec_clock::local_time();

        	msg.str("");
                msg << "Release " << release_hosts << " hosts, " << release_issuers << " issuers, ";
		msg << release_sessions << " sessions, ";
#if defined(HAVE_JA3)
		msg << release_fingerprints << " fingerprints, ";
#endif
		msg << release_flows << " flows";
		computeMemoryUtilization(msg, total_cache_bytes_released, total_bytes_released_by_flows, total_cache_save_bytes);
                infoMessage(msg.str());
        }
}

void SSLProtocol::releaseFlowInfo(Flow *flow) {

	// Return the flow's SSLInfo (if any) to the info cache
	auto info = flow->getSSLInfo();
	if (!info)
		return;

	info_cache_->release(info);
}

void SSLProtocol::attach_common_name(SSLInfo *info, const boost::string_ref &name) {

	// The issuer is attached only once per SSLInfo
	if (info->issuer)
		return;

	// Reuse an already known issuer string and count the hit
	auto known = issuer_map_.find(name);
	if (known != issuer_map_.end()) {
		++known->second.hits;
		info->issuer = known->second.sc;
		return;
	}

	// Otherwise take a fresh string from the cache (silently skipped if the cache is empty)
	auto fresh = issuer_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(name.data(), name.length());
	info->issuer = fresh;
	issuer_map_.insert(std::make_pair(fresh->name(), fresh));
}

void SSLProtocol::attach_host(SSLInfo *info, const boost::string_ref &host) {

	// The host name is attached only once per SSLInfo
	if (info->host_name)
		return;

	// Reuse an already known host string and count the hit
	auto known = host_map_.find(host);
	if (known != host_map_.end()) {
		++known->second.hits;
		info->host_name = known->second.sc;
		return;
	}

	// Otherwise take a fresh string from the cache (silently skipped if the cache is empty)
	auto fresh = host_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(host.data(), host.length());
	info->host_name = fresh;
	host_map_.insert(std::make_pair(fresh->name(), fresh));
}

void SSLProtocol::attach_session(SSLInfo *info, const boost::string_ref &session) {

	// The session id is attached only once per SSLInfo
	if (info->session_id)
		return;

	// Reuse an already known session string and count the hit
	auto known = session_map_.find(session);
	if (known != session_map_.end()) {
		++known->second.hits;
		info->session_id = known->second.sc;
		return;
	}

	// Otherwise take a fresh string from the cache (silently skipped if the cache is empty)
	auto fresh = session_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(session.data(), session.length());
	info->session_id = fresh;
	session_map_.insert(std::make_pair(fresh->name(), fresh));
}

#if defined(HAVE_JA3)
void SSLProtocol::attach_ja3_fingerprint(SSLInfo *info, const boost::string_ref &fingerprint) {

	// The JA3 fingerprint is attached only once per SSLInfo
	if (info->ja3_fingerprint)
		return;

	// Reuse an already known fingerprint string and count the hit
	auto known = ja3_map_.find(fingerprint);
	if (known != ja3_map_.end()) {
		++known->second.hits;
		info->ja3_fingerprint = known->second.sc;
		return;
	}

	// Otherwise take a fresh string from the cache (silently skipped if the cache is empty)
	auto fresh = ja3_cache_->acquire();
	if (!fresh)
		return;

	fresh->name(fingerprint.data(), fingerprint.length());
	info->ja3_fingerprint = fresh;
	ja3_map_.insert(std::make_pair(fresh->name(), fresh));
}
#endif

void SSLProtocol::handle_client_hello(SSLInfo *info, const uint8_t *data, int length) {

	// Decode a ClientHello: session id, server name (SNI), heartbeat flag and,
	// with HAVE_JA3, the JA3 fingerprint.
	// data points at the hello message and length is the number of its bytes
	// present in the buffer. All reads are bounded by length because the
	// content comes straight from the network.
#if defined(DEBUG)
	std::cout << __FILE__ << ":" << __func__ << ":length:" << std::dec << length << std::endl;
#endif
	const ssl_hello *hello = reinterpret_cast<const ssl_hello*>(data);
	int block_offset = sizeof(ssl_hello);

	++total_client_hellos_;

	// The fixed part of the hello must be present before any of its fields is read
	if (length < block_offset)
		return;

	uint16_t version = ntohs(hello->version);

	if ((version >= SSL3_VERSION)and(version <= TLS1_2_VERSION)) {
		int len = ntohs(hello->length);

		if (ntohs(hello->session_id_length) > 0) { // Session id management
			// The session id is read as 32 bytes; give up if they are not in the buffer
			if (block_offset + 32 > length)
				return;

			block_offset += 32;
			std::ostringstream session;

			for (int i = 0; i < 32; ++i)
				session << std::hex << std::setw(2) << std::setfill('0') << (int)hello->data[i];

			attach_session(info, boost::string_ref(session.str()));
		}

		// The two-byte cipher suites length must be readable
		if (block_offset + 2 > length)
			return;

		uint16_t cipher_length = ntohs((data[block_offset + 1] << 8) + data[block_offset]);
		if ((cipher_length < len)and(block_offset + 2 + cipher_length <= length)) {
#if defined(HAVE_JA3)
			std::ostringstream ja3, ja3ext, ja3group, ja3ec;

			ja3 << version << ",";
			// Two bytes per cipher suite, emitted as decimal values joined by '-'.
			// An empty or odd trailing byte is not emitted.
			const uint8_t *cipher = &data[block_offset + 2];
			const int total_ciphers = cipher_length / 2;
			for (int i = 0; i < total_ciphers; ++i) {
				if (i > 0)
					ja3 << "-";
				ja3 << ntohs((cipher[2 * i + 1] << 8) + cipher[2 * i]);
			}
#endif
			block_offset += cipher_length + 2;

			// compression_methods: one length byte followed by that many method bytes.
			// The length byte is always skipped, even when it is zero.
			if (block_offset < length) {
				const uint8_t compression_length = data[block_offset];
				block_offset += compression_length + 1;
			}

			if ((block_offset < len)and(block_offset + 2 <= length)) {
				// Skip the two-byte total extensions length; every extension is bounded below
				block_offset += 2;
				while (block_offset + (int)sizeof(ssl_extension) <= length) {
					const ssl_extension *extension = reinterpret_cast<const ssl_extension*>(&data[block_offset]);
					const uint8_t *ext_body = reinterpret_cast<const uint8_t*>(&extension->data[0]);
					// Bytes of this extension's body that are both declared and present in the buffer
					const int declared_length = ntohs(extension->length);
					const int present_length = length - (block_offset + (int)sizeof(ssl_extension));
					const int ext_avail = (declared_length < present_length) ? declared_length : present_length;
#if defined(HAVE_JA3)
					ja3ext << ntohs((data[block_offset + 1] << 8) + data[block_offset]) << "-";
#endif
					// extension->type is compared without ntohs, so the case labels are the
					// byte-swapped wire values (e.g. heartbeat 0x000F appears as 0x0F00)
					switch (extension->type) {
						case 0x0000: { // Server name
							const ssl_server_name *server = reinterpret_cast<const ssl_server_name*>(ext_body);
							const int name_offset = static_cast<int>(reinterpret_cast<const uint8_t*>(&server->data[0]) - ext_body);

							// The server name header itself must be inside the extension
							if (name_offset > ext_avail)
								break;

							int server_length = ntohs(server->length);

							if ((block_offset + server_length < len)and(server_length > 0)and(name_offset + server_length <= ext_avail)) {
								boost::string_ref servername((char*)server->data, server_length);

								if (ban_domain_mng_) {
									if (auto host_candidate = ban_domain_mng_->getDomainName(servername); host_candidate) {
										++total_ban_hosts_;
										info->setIsBanned(true);
										return;
									}
								}
								++total_allow_hosts_;

								attach_host(info, servername);
							}
							break;
						}
						case 0x01FF: // Renegotiation
							break;
						case 0x0F00: // heartbeat
							info->setHeartbeat(true);
							break;
						case 0x0D00: // Signature algorithms
							break;
						case 0x2300: // Session ticket
							break;
#if defined(HAVE_JA3)
						case 0x0A00: { // Groups
							if (ext_avail < 2)
								break;

							int group_length = ntohs((extension->data[1] << 8) + extension->data[0]);
							if (group_length > ext_avail - 2)
								group_length = ext_avail - 2;

							const uint8_t *group = &extension->data[2];
							for (int i = 0; i + 1 < group_length; i += 2)
								ja3group << ntohs((group[i + 1] << 8) + group[i]) << "-";

							break;
						}
						case 0x0B00: { // Elliptic curve point formats
							if (ext_avail < 1)
								break;

							const uint8_t declared_formats = extension->data[0];
							const int ec_len = (declared_formats < ext_avail - 1) ? declared_formats : ext_avail - 1;
							const uint8_t *ec_format = &extension->data[1];
							for (int i = 0; i < ec_len; ++i) {
								if (i > 0)
									ja3ec << "-";
								ja3ec << (int)ec_format[i];
							}
							break;
						}
#endif
						case 0xCEFF: { // Encrypted server name: attach the esni_label as host
							boost::string_ref servername((char*)esni_label.data(), esni_label.length());
							attach_host(info, servername);
							break;
						}
					}
					block_offset += declared_length + (int)sizeof(ssl_extension);
				}
			}
#if defined(HAVE_JA3)
			// Assemble "version,ciphers,extensions,groups,ec_formats"; the extension
			// and group lists carry a trailing '-' that is dropped here.
			// NOTE(review): GREASE values are not filtered out as the JA3 reference does.
			MD5_CTX md5;
			uint8_t hash[16];
			std::ostringstream digest;
			std::string ja3signature(ja3.str());
			const std::string extensions_field(ja3ext.str());
			const std::string groups_field(ja3group.str());

			ja3signature.append(",");
			if (!extensions_field.empty())
				ja3signature.append(extensions_field, 0, extensions_field.length() - 1);
			ja3signature.append(",");
			if (!groups_field.empty())
				ja3signature.append(groups_field, 0, groups_field.length() - 1);
			ja3signature.append(",");
			ja3signature.append(ja3ec.str());

			// Compute the MD5 of the generated signature
			MD5_Init(&md5);
			MD5_Update(&md5, ja3signature.c_str(), ja3signature.length());
			MD5_Final(hash, &md5);

			for (int i = 0; i < (int)sizeof(hash); ++i)
				digest << std::hex << std::setw(2) << std::setfill('0') << (int)hash[i];

			// Attach generated signature if needed
			attach_ja3_fingerprint(info, digest.str());
#endif
		}
	} // end version
}

void SSLProtocol::handle_server_hello(SSLInfo *info, const uint8_t *data, int length) {

	// Decode a ServerHello: stores the selected cipher suite for SSL3..TLS1.2,
	// or marks the session as TLS1.3 when the draft version value is seen.
	// data points at the hello message; length is the number of its bytes in the buffer.
#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << ":length:" << std::dec << length << std::endl;
#endif
	const ssl_hello *hello = reinterpret_cast<const ssl_hello*>(data);
	int block_offset = sizeof(ssl_hello);

	++total_server_hellos_;

	// The fixed part of the hello must be present before any of its fields is read
	if (length < block_offset)
		return;

	uint16_t version = ntohs(hello->version);

	if ((version >= SSL3_VERSION)and(version <= TLS1_2_VERSION)) {
		if (ntohs(hello->session_id_length) > 0) // Session id management
			block_offset += 32;

		// The two-byte cipher suite follows the session id
		if (block_offset + 2 <= length) {
			uint16_t cipher_id = ntohs((data[block_offset + 1] << 8) + data[block_offset]);
			info->setCipher(cipher_id);
		}
	} else if (hello->version == 0x0E7F) { // This is TLS1.3 draft
		info->setVersion(TLS1_3_VERSION);
	}
}

uint8_t SSLProtocol::get_asn1_length(uint8_t byte) {

	// Bit 7 (0x80) is always cleared; when it was set, bit 6 (0x40) is cleared too.
	// NOTE(review): for DER long-form lengths (0x81, 0x82, ...) the result is the
	// number of following length octets, not the length itself.
	const uint8_t mask = (byte & 0x80) ? 0x3F : 0x7F;

	return static_cast<uint8_t>(byte & mask);
}

void SSLProtocol::handle_issuer_certificate(SSLInfo *info, const uint8_t *data, int length) {

	// Scan the DER bytes starting 4 bytes into data for the first
	// id-at-commonName attribute (OID 2.5.4.3, encoded 06 len 55 04 03).
	// The first value with a supported string type is attached as the issuer
	// and the scan stops.
	const uint8_t *cursor = &data[4];
	const int scan_end = length - (8 + 9);

#ifdef DEBUG
	std::cout << __FILE__ << ":" <<  __func__ << ":len:" << std::dec << length << std::endl;
	showPayload(std::cout, data, length - 8);
#endif
	for (int off = 0; off < scan_end; ++off, ++cursor) {
		// OBJECT IDENTIFIER tag
		if (cursor[0] != 0x06)
			continue;

		// 2.5.4.3 id-at-commonName
		if (!((cursor[2] == 0x55) and (cursor[3] == 0x04) and (cursor[4] == 0x03)))
			continue;

		// PrintableString (0x13), TeletexString (0x14) or UTF8String (0x0C)
		const uint8_t value_type = cursor[5];
		if ((value_type != 0x13) and (value_type != 0x14) and (value_type != 0x0C))
			continue;

		// NOTE(review): the value length comes from the packet and is not checked against length
		const uint8_t value_length = get_asn1_length(cursor[6]);
		boost::string_ref name((char*)&cursor[7], value_length);
#ifdef DEBUG
		std::cout << __FILE__ << ":" <<  __func__ << ":commonName:" << name << std::endl;
#endif
		attach_common_name(info, name);
		return;
	}
}

void SSLProtocol::handle_certificate(SSLInfo *info, const uint8_t *data, int length) {

	// Walk the start of the first certificate (ASN.1 DER, located after the
	// ssl_cert header) past the explicit version [0] and the serialNumber, and
	// hand the following SEQUENCE to handle_issuer_certificate().
	// data points at the Certificate message; length is the number of its bytes
	// present in the buffer. Every read is bounded by length.
#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << ":len:" << std::dec << length << std::endl;
#endif
	// Every Certificate message is counted, even when too short to decode
	++total_certificates_;

	int pos = sizeof(ssl_cert); // offset of the DER body within data

	// The tag/length pair at body offsets 11 and 12 must be readable
	if (pos + 13 > length)
		return;

	const ssl_cert *record = reinterpret_cast<const ssl_cert*>(data);
	[[maybe_unused]] uint16_t type = record->type;
	[[maybe_unused]] int16_t len = ntohs(record->cert_length);

#ifdef DEBUG
	std::cout << "CERT Payload, rlen:" << len << " len:" << length << " atype:" << std::hex << (short)data[pos] << "\n";
	// showPayload(std::cout, data, length);
#endif
	uint8_t alen = get_asn1_length(data[pos + 12]);
	uint8_t atype = data[pos + 11];

	if (atype == 0xA0) { // Enumerated of the items
		pos += 11 + alen + 2;
		if (pos + 2 > length)
			return;

		atype = data[pos];
		alen = get_asn1_length(data[pos + 1]);
		if (atype == 0x02) { // serialNumber integer
			pos += alen + 2;
			if (pos + 1 > length)
				return;

			atype = data[pos];
			if (atype == 0x30) { // signature
				// Pass the number of bytes actually available from this position
				handle_issuer_certificate(info, &data[pos], length - pos);
			}
		}
	}
}

void SSLProtocol::handle_handshake(SSLInfo *info, const ssl_record *record, int length) {

	// Decode up to three handshake messages contained in one SSL record.
	// length is the number of record body bytes (starting at record->data)
	// available in the packet. Records that are too short, or messages whose
	// declared length exceeds length (except certificates), raise an
	// SSL_BOGUS_HEADER anomaly on the current flow.
	uint16_t version = ntohs(record->version);
	uint8_t type = record->data[0];
	int record_length = ntohs(record->length);

#ifdef DEBUG
	std::cout << __FILE__ << ":" << __func__ << ":Container:len:" << length;
	std::cout << " rlen:" << record_length << " type:" << int(type) << " version:" << version  << std::endl;
	showPayload(std::cout, (const uint8_t*)record, length);
#endif
	if ((version >= SSL3_VERSION)and(version <= TLS1_2_VERSION)) {
		info->setVersion(version);

		// This is a valid SSL header that we could extract some useful information from.
		// SSL Records are grouped by blocks

		int max_records = 0;
		int offset = 0;
		const uint8_t *ssl_data = record->data;

		if (length < (int)sizeof(handshake_record)) {
			++total_events_;
			if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
				current_flow_->setPacketAnomaly(PacketAnomalyType::SSL_BOGUS_HEADER);

			anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SSL_BOGUS_HEADER);
			return;
		}

		do {
			const handshake_record *hsk_record = reinterpret_cast<const handshake_record*>(&ssl_data[offset]);
			record_length = ntohs(hsk_record->length);
			type = hsk_record->type;
#ifdef DEBUG
			std::cout << __FILE__ << ":" << __func__ << ":record:len:" << length << " rlen:" << record_length;
			std::cout << " type:" << int(type) << " offset:" << offset << std::endl;
#endif
			++max_records;

			if ((record_length > length)and(type != SSL3_MT_CERTIFICATE)) {
				++total_events_;
				if (current_flow_->getPacketAnomaly() == PacketAnomalyType::NONE)
					current_flow_->setPacketAnomaly(PacketAnomalyType::SSL_BOGUS_HEADER);

				anomaly_->incAnomaly(current_flow_, PacketAnomalyType::SSL_BOGUS_HEADER);
				return;
			}

			// Every decoder receives the current message positioned at its own
			// header, together with the bytes left in the record from there.
			const uint8_t *message = &ssl_data[offset];
			const int message_length = length - offset;

			if (type == SSL3_MT_CLIENT_HELLO) {
				handle_client_hello(info, message, message_length);
			} else if (type == SSL3_MT_SERVER_HELLO)  {
				handle_server_hello(info, message, message_length);
			} else if (type == SSL3_MT_CERTIFICATE) {
				handle_certificate(info, message, message_length);
				if (current_flow_->getFlowDirection() == FlowDirection::FORWARD)
					info->setMtls(true);
			} else if (type == SSL3_MT_SERVER_KEY_EXCHANGE) {
				++total_server_key_exchanges_;
			} else if (type == SSL3_MT_NEW_SESSION_TICKET) {
				++total_new_session_tickets_;
			} else if (type == SSL3_MT_CERTIFICATE_REQUEST) {
				++total_certificate_requests_;
			} else if (type == SSL3_MT_SERVER_DONE) {
				++total_server_dones_;
			} else if (type == SSL3_MT_CERTIFICATE_VERIFY) {
				++total_certificate_verifies_;
			} else if (type == SSL3_MT_CLIENT_KEY_EXCHANGE) {
				++total_client_key_exchanges_;
			} else if (type >= SSL3_MT_FINISHED) { // or the data is encrypted, need to recheck this
				++total_handshake_finishes_;
			}

			offset += record_length + sizeof(handshake_record);

		} while ((offset + (int)sizeof(handshake_record) <= length) and (max_records < 3));
	}
}

void SSLProtocol::processFlow(Flow *flow) {

	// Entry point for every SSL packet of a flow. Attaches an SSLInfo from the
	// info cache on first use, then walks up to five SSL records in the payload
	// (handshake, change cipher spec, application data, alert) and, on the
	// first L7 packet, matches the host name against the domain manager.
	CPUCycle cycles(&total_cpu_cycles_);
	++total_packets_;
	int length = flow->packet->getLength();
	total_bytes_ += length;
	++flow->total_packets_l7;

	auto info = flow->getSSLInfo();
	if (!info) {
		if (info = info_cache_->acquire(); !info) {
			logFailCache(info_cache_->name(), flow);
			return;
		}
		flow->layer7info = info;
	}

	if (info->isBanned() == false) {

		current_flow_ = flow;

		if (length > header_size) {
			setHeader(flow->packet->getPayload());

			if ((header_->type >= SSL3_CT_CHANGE_CIPHER_SPEC)and(header_->type <= SSL3_CT_APPLICATION_DATA)) {
				int record_length = ntohs(header_->length);

				if (record_length > 0) {
					const uint8_t *payload = flow->packet->getPayload();
					int offset = 0;         // Total offset byte
					int maxattemps = 0;     // For prevent invalid decodings

					do {
						const ssl_record *record = reinterpret_cast<const ssl_record*>(&payload[offset]);
						[[maybe_unused]] uint16_t version = ntohs(record->version);
						uint8_t type = record->type;
						record_length = ntohs(record->length);
						++maxattemps;
#ifdef DEBUG
						std::cout << __FILE__ << ":" << __func__ << ":len:" << length << " rlen:" << record_length;
						std::cout << " type: " << int(type) << " offset:" << offset << std::endl;
#endif
						if (type == SSL3_CT_HANDSHAKE) {
							// The record body starts right after the ssl_record header; if the
							// declared length runs past the packet, truncate it to the body
							// bytes actually present.
							if (offset + record_length + (int)sizeof(ssl_record) > length)
								record_length = length - (offset + (int)sizeof(ssl_record));

							// The handshake could be encrypted
							if (info->isEncrypted() == false) {
								handle_handshake(info.get(), record, record_length);
								++total_handshakes_;
							} else
								++total_encrypted_handshakes_;

						} else if (type == SSL3_CT_CHANGE_CIPHER_SPEC) {
							++total_change_cipher_specs_;
							info->setEncrypted(true); // From this point all should be encrypted
						} else if (type == SSL3_CT_APPLICATION_DATA) { // On Tls1.3 encrypted data can be sent
							++total_data_;
							info->incDataPdus();
						} else if (type == SSL3_CT_ALERT) {
							// Is an Alert message
							info->setAlert(true);
							++total_alerts_;
							// Regular length of alerts; the code byte must also be inside the packet
							if ((record_length >= 2)and(offset + (int)sizeof(ssl_record) + 2 <= length)) {
								int8_t value = record->data[1];
								info->setAlertCode(value);
							}
						}

						++total_records_;

						offset += record_length + sizeof(ssl_record);

						if (maxattemps == 5) break; // Maximum Pdus per packet allowed
					} while (offset + (int)sizeof(ssl_record) < length);
				}

				if ((flow->total_packets_l7 == 1)and(domain_mng_)) {
					const char *name = info->host_name ? info->host_name->name() : "";

					if (auto host_candidate = domain_mng_->getDomainName(name); host_candidate) {
						++total_events_;
						info->matched_domain_name = host_candidate;
#if defined(BINDING)
						if (host_candidate->call.haveCallback())
							host_candidate->call.executeCallback(flow);
#endif
					}
				}
			}
		}
	}
}

void SSLProtocol::setDomainNameManager(const SharedPointer<DomainNameManager> &dm) {

	// Unplug the previous manager (if any) before replacing it
	if (domain_mng_)
		domain_mng_->setPluggedToName("");

	// Assigning an empty pointer leaves the protocol without a domain manager
	domain_mng_ = dm;

	if (domain_mng_)
		domain_mng_->setPluggedToName(name());
}

void SSLProtocol::statistics(std::basic_ostream<char> &out, int level, int32_t limit) const {

	// Text statistics: plugged managers from level 1, counters and caches from
	// level 4, maps from level 5 and the flow forwarder from level 6.
	showStatisticsHeader(out, level);

	if (level > 0) {
		if (ban_domain_mng_)
			out << "\t" << "Plugged banned domains from:" << ban_domain_mng_->name() << "\n";
		if (domain_mng_)
			out << "\t" << "Plugged domains from:" << domain_mng_->name() << "\n";
	}

	if (level > 3) {
		// One tab-indented "label value" line with the value right aligned in width columns
		auto row = [&out](const char *label, int width, const auto &value) {
			out << "\t" << label << std::setw(width) << value << "\n";
		};

		row("Total allow hosts:      ", 10, total_allow_hosts_);
		row("Total banned hosts:     ", 10, total_ban_hosts_);
		row("Total handshakes:       ", 10, total_handshakes_);
		row("Total encrypt handshakes:", 9, total_encrypted_handshakes_);
		row("Total alerts:           ", 10, total_alerts_);
		row("Total change cipher specs:", 8, total_change_cipher_specs_);
		row("Total data:             ", 10, total_data_);
		row("Total client hellos:    ", 10, total_client_hellos_);
		row("Total server hellos:    ", 10, total_server_hellos_);
		row("Total certificates:     ", 10, total_certificates_);
		row("Total server key exs:   ", 10, total_server_key_exchanges_);
		row("Total certificate reqs: ", 10, total_certificate_requests_);
		row("Total server dones:     ", 10, total_server_dones_);
		row("Total certificates vers:", 10, total_certificate_verifies_);
		row("Total client key exs:   ", 10, total_client_key_exchanges_);
		row("Total new session tickets:", 8, total_new_session_tickets_);
		row("Total handshakes finish:", 10, total_handshake_finishes_);
		out << "\t" << "Total records:          " << std::setw(10) << total_records_ << std::endl;
	}

	if (level > 5) {
		if (auto forwarder = flow_forwarder_.lock(); forwarder)
			forwarder->statistics(out);
	}

	if (level > 3) {
		info_cache_->statistics(out);
		host_cache_->statistics(out);
		issuer_cache_->statistics(out);
		session_cache_->statistics(out);
#if defined(HAVE_JA3)
		ja3_cache_->statistics(out);
#endif
		if (level > 4) {
			host_map_.show(out, "\t", limit);
			issuer_map_.show(out, "\t", limit);
			session_map_.show(out, "\t", limit);
#if defined(HAVE_JA3)
			ja3_map_.show(out, "\t", limit);
#endif
		}
	}
}

void SSLProtocol::statistics(Json &out, int level) const {

	// JSON statistics: the counters are only exported for level > 3
	showStatisticsHeader(out, level);

	if (level <= 3)
		return;

	out["allow hosts"] = total_allow_hosts_;
	out["banned hosts"] = total_ban_hosts_;
	out["handshakes"] = total_handshakes_;
	out["encrypt handshakes"] = total_encrypted_handshakes_;
	out["alerts"] = total_alerts_;
	out["change_cipher_specs"] = total_change_cipher_specs_;
	out["data"] = total_data_;

	// Per handshake message type counters, nested under "types"
	Json types;

	types["client hellos"] = total_client_hellos_;
	types["server hellos"] = total_server_hellos_;
	types["certificates"] = total_certificates_;
	types["server keys"] = total_server_key_exchanges_;
	types["certificate requests"] = total_certificate_requests_;
	types["server dones"] = total_server_dones_;
	types["certificate verifies"] = total_certificate_verifies_;
	types["client keys"] = total_client_key_exchanges_;
	types["handshake finish"] = total_handshake_finishes_;
	types["new session tickets"] = total_new_session_tickets_;
	types["records"] = total_records_;

	out["types"] = types;
}

void SSLProtocol::increaseAllocatedMemory(int value) {

	// Create the same number of items in every cache
	auto grow = [value](auto &cache) { cache->create(value); };

	grow(info_cache_);
	grow(host_cache_);
	grow(issuer_cache_);
	grow(session_cache_);
#if defined(HAVE_JA3)
	grow(ja3_cache_);
#endif
}

void SSLProtocol::decreaseAllocatedMemory(int value) {

	// Destroy the same number of items in every cache
	auto shrink = [value](auto &cache) { cache->destroy(value); };

	shrink(info_cache_);
	shrink(host_cache_);
	shrink(issuer_cache_);
	shrink(session_cache_);
#if defined(HAVE_JA3)
	shrink(ja3_cache_);
#endif
}

CounterMap SSLProtocol::getCounters() const {

	// Export the protocol counters as key/value pairs, in a fixed order
	CounterMap counters;
	auto add = [&counters](const char *key, const auto &value) { counters.addKeyValue(key, value); };

	// Traffic and host filtering
	add("packets", total_packets_);
	add("bytes", total_bytes_);
	add("allow hosts", total_allow_hosts_);
	add("banned hosts", total_ban_hosts_);

	// Record level
	add("handshakes", total_handshakes_);
	add("encrypt handshakes", total_encrypted_handshakes_);
	add("alerts", total_alerts_);
	add("change cipher specs", total_change_cipher_specs_);
	add("datas", total_data_);

	// Handshake message level
	add("client hellos", total_client_hellos_);
	add("server hellos", total_server_hellos_);
	add("certificates", total_certificates_);
	add("server key exchanges", total_server_key_exchanges_);
	add("certificate requests", total_certificate_requests_);
	add("server dones", total_server_dones_);
	add("certificate verifies", total_certificate_verifies_);
	add("client key exchanges", total_client_key_exchanges_);
	add("handshake dones", total_handshake_finishes_);
	add("new session tickets", total_new_session_tickets_);
	add("records", total_records_);

	return counters;
}

#if defined(PYTHON_BINDING) || defined(RUBY_BINDING)
#if defined(PYTHON_BINDING)
boost::python::dict SSLProtocol::getCacheData(const std::string &name) const {
#elif defined(RUBY_BINDING)
VALUE SSLProtocol::getCacheData(const std::string &name) const {
#endif
	// Select the map by its case-insensitive name and convert it for the binding
	if (boost::iequals(name, "host"))
		return addMapToHash(host_map_);

	if (boost::iequals(name, "issuer"))
		return addMapToHash(issuer_map_);

	if (boost::iequals(name, "session"))
		return addMapToHash(session_map_);
#if defined(HAVE_JA3)
	if (boost::iequals(name, "ja3"))
		return addMapToHash(ja3_map_);
#endif
	// Unknown name: convert a placeholder StringMap instead
	StringMap empty {"", ""};

	return addMapToHash(empty);
}

#if defined(PYTHON_BINDING)
SharedPointer<Cache<StringCache>> SSLProtocol::getCache(const std::string &name) {

	// Select the string cache by its case-insensitive name; nullptr when unknown
	if (boost::iequals(name, "host"))
		return host_cache_;

	if (boost::iequals(name, "issuer"))
		return issuer_cache_;

	if (boost::iequals(name, "session"))
		return session_cache_;
#if defined(HAVE_JA3)
	if (boost::iequals(name, "ja3"))
		return ja3_cache_;
#endif
	return nullptr;
}

#endif

#endif

void SSLProtocol::statistics(Json &out, const std::string &map_name, int32_t limit) const {

	// Dump one lookup map, chosen by case-insensitive name; unknown names are ignored
	if (boost::iequals(map_name, "hosts"))
		host_map_.show(out, limit);
	else if (boost::iequals(map_name, "issuers"))
		issuer_map_.show(out, limit);
	else if (boost::iequals(map_name, "sessions"))
		session_map_.show(out, limit);
#if defined(HAVE_JA3)
	else if (boost::iequals(map_name, "fingerprints"))
		ja3_map_.show(out, limit);
#endif
}

void SSLProtocol::resetCounters() {

	// Reset the common counters first, then every SSL specific one
	reset();

	// Events and host filtering
	total_events_ = 0;
	total_allow_hosts_ = 0;
	total_ban_hosts_ = 0;

	// Record level
	total_records_ = 0;
	total_handshakes_ = 0;
	total_encrypted_handshakes_ = 0;
	total_alerts_ = 0;
	total_change_cipher_specs_ = 0;
	total_data_ = 0;

	// Handshake message level
	total_client_hellos_ = 0;
	total_server_hellos_ = 0;
	total_certificates_ = 0;
	total_server_key_exchanges_ = 0;
	total_certificate_requests_ = 0;
	total_server_dones_ = 0;
	total_certificate_verifies_ = 0;
	total_client_key_exchanges_ = 0;
	total_handshake_finishes_ = 0;
	total_new_session_tickets_ = 0;
}

} // namespace aiengine
